from .Main2DNet import *
from .Main3DNet import *
from .LightNet import *
from .Resnet3d import *
from .Mictresnet import *
from .C3DNet import *


def get_classification_model(name, **kwargs):
    """Build a classification network by its registry name.

    Args:
        name: Case-insensitive model key. One of ``main3dnet``, ``main2dnet``,
            ``lightnet``, ``resnet3d``, ``mictresnet`` or ``c3dnet``.
        **kwargs: Forwarded unchanged to the selected factory function.

    Returns:
        Whatever the selected factory (``get_net_3d``, ``get_net_2d``,
        ``get_net_light``, ``get_resnet3d``, ``get_mictresnet`` or
        ``get_net_c3d``) returns.

    Raises:
        KeyError: If ``name`` does not match any registered model. The
            message lists the accepted names.
    """
    # Factories are provided by the star imports at the top of this module.
    models = {
        'main3dnet': get_net_3d,
        'main2dnet': get_net_2d,
        'lightnet': get_net_light,
        'resnet3d': get_resnet3d,
        'mictresnet': get_mictresnet,
        'c3dnet': get_net_c3d,
    }
    # Keep the try scoped to the lookup only, so a KeyError raised inside a
    # factory is not misreported as an unknown model name.
    try:
        factory = models[name.lower()]
    except KeyError:
        raise KeyError(
            "Unknown classification model {!r}; expected one of: {}".format(
                name, ', '.join(sorted(models))
            )
        ) from None
    return factory(**kwargs)
